package org.tio.showcase.websocket.server.processor;

import cn.hutool.core.lang.Snowflake;
import cn.hutool.core.thread.NamedThreadFactory;
import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
import org.tio.http.common.HttpRequest;
import org.tio.http.common.HttpResponse;
import org.tio.showcase.websocket.server.Const;
import org.tio.showcase.websocket.server.pojo.Msg;
import org.tio.showcase.websocket.server.pojo.User;
import org.tio.showcase.websocket.server.util.MsgUtil;
import org.tio.websocket.common.WsRequest;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.JedisPubSub;

import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * 基于jedis的发布订阅器
 *
 * @author huanglin
 */
public class ServerProcessorOnPubSub implements ServerProcessor {

    private static Logger log = LoggerFactory.getLogger(ServerProcessorOnPubSub.class);

    private JedisPool jedisPool;

    private Snowflake snowflake;

    public ServerProcessorOnPubSub() {
        //这里可以根据配置来设置节点id，如果跨机房的话也可以单独设置数据中心的id
        snowflake = new Snowflake(0, 0);
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        poolConfig.setMaxIdle(10);
        poolConfig.setMaxTotal(100);
        poolConfig.setMaxWaitMillis(5000L);
        poolConfig.setTestOnBorrow(true);
        jedisPool = new JedisPool(poolConfig, "10.1.5.16", 6379, 5000);
        subscribeMsg();

    }

    @Override
    public void onAfterHandshaked(HttpRequest httpRequest, HttpResponse httpResponse, ChannelContext channelContext) throws Exception {
        String username = channelContext.getUserid();
        //TODO 如查询当前用户所在组的功能
        //Set<String> groups = userService.getUserGroups(username);
        // for 循环 ：Aio.bindGroup(channelContext, group);
        //不管之前是否已经登录，直接覆盖，实际业务时会有具体处理
        User user = new User();
        // user.setGroup(groups);
        user.setUsername(username);
        user.setNode(channelContext.getServerNode().toString());
        JedisOperatorHelper.set(jedisPool, Const.WS_USER_PREFIX + username, JSON.toJSONString(user));
        JedisOperatorHelper.addOnlieUser(jedisPool, username);
        log.info("用户{}加入，当前总用户{}", username, JedisOperatorHelper.countOnlineUser(jedisPool));
    }

    @Override
    public void onBeforeClose(ChannelContext channelContext, Throwable throwable, String remark, boolean isRemove) throws Exception {
        String username = channelContext.getUserid();
        JedisOperatorHelper.del(jedisPool, Const.WS_USER_PREFIX + username);
        JedisOperatorHelper.delOnlineUser(jedisPool, username);
        log.info("用户{}离开，当前总用户{}", username, JedisOperatorHelper.countOnlineUser(jedisPool));
    }

    @Override
    public Object onText(WsRequest wsRequest, String text, ChannelContext channelContext) throws Exception {
        Msg msg = JSON.parseObject(text, Msg.class);
        //心跳信息则不用理
        if (Const.Action.HEART_BEAT_REQ.val() != msg.getAction()) {
            msg.setId(String.valueOf(snowflake.nextId()));
            //目标在本实例节点的直接发
            if (MsgUtil.existsUser(msg.getTo())) {
                processMsg(false, msg);
            } else {
                JedisOperatorHelper.publish(jedisPool, JSON.toJSONString(msg), Const.WS_MSG_TOPIC_CHANNEL);
            }
        }
        return null;
    }

    private void subscribeMsg() {
        ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 10, TimeUnit.SECONDS, new LinkedBlockingDeque<>(), new NamedThreadFactory("ws_msg_subscribe", false));
        executor.execute(() -> {
                    Jedis jedis = jedisPool.getResource();
                    jedis.subscribe(new JedisPubSub() {
                        @Override
                        public void onMessage(String channel, String message) {
                            processMsg(true, JSON.parseObject(message, Msg.class));
                        }
                    }, Const.WS_MSG_TOPIC_CHANNEL);
                }
        );
        log.info("订阅通道完成");
    }

    private void processMsg(boolean isPublish, Msg msg) {
        int action = msg.getAction();
        Msg respMsg = new Msg();
        //响应信息则直接返回给客户端即可
        if (action % 11 == 0 && MsgUtil.existsUser(msg.getTo())) {
            //重新包装下后再发送
            respMsg.setMsg(msg.getMsg());
            respMsg.setAction(msg.getAction());
            respMsg.setStatus(msg.getStatus());
            respMsg.setId(msg.getId());
            MsgUtil.sendToUser(msg.getTo(), respMsg);
        } else {
            respMsg.setTo(msg.getFrom());
            respMsg.setStatus("200");
            respMsg.setId(msg.getId());
            if (action == Const.Action.P2P_MSG_REQ.val()) {
                respMsg.setAction(Const.Action.P2P_MSG_RESP.val());
                if (MsgUtil.existsUser(msg.getTo())) {
                    MsgUtil.sendToUser(msg.getTo(), msg);
                    if (isPublish) {
                        JedisOperatorHelper.publish(jedisPool, JSON.toJSONString(respMsg), Const.WS_MSG_TOPIC_CHANNEL);
                    } else {
                        MsgUtil.sendToUser(msg.getFrom(), respMsg);
                    }
                }
            } else if (action == Const.Action.GROUP_MSG_REQ.val()) {
                MsgUtil.sendToGroup(msg.getTo(), msg);
                respMsg.setAction(Const.Action.GROUP_MSG_RESP.val());
                if (isPublish) {
                    JedisOperatorHelper.publish(jedisPool, JSON.toJSONString(respMsg), Const.WS_MSG_TOPIC_CHANNEL);
                } else {
                    MsgUtil.sendToUser(msg.getFrom(), respMsg);
                }
            }
        }
    }

    static class JedisOperatorHelper {
        public static void set(JedisPool pool, String key, String value) {
            try (Jedis jedis = pool.getResource()) {
                jedis.set(key, value);
            } catch (Exception e) {
                log.error("", e);
            }
        }

        public static void del(JedisPool pool, String key) {
            try (Jedis jedis = pool.getResource()) {
                jedis.del(key);
            } catch (Exception e) {
                log.error("", e);
            }
        }

        public static void publish(JedisPool pool, String message, String channel) {
            try (Jedis jedis = pool.getResource()) {
                jedis.publish(channel, message);
            } catch (Exception e) {
                log.error("", e);
            }
        }

        public static void addOnlieUser(JedisPool pool, String username) {
            try (Jedis jedis = pool.getResource()) {
                jedis.sadd(Const.WS_USER_ONLINE, username);
            } catch (Exception e) {
                log.error("", e);
            }
        }

        public static void delOnlineUser(JedisPool pool, String username) {
            try (Jedis jedis = pool.getResource()) {
                jedis.srem(Const.WS_USER_ONLINE, username);
            } catch (Exception e) {
                log.error("", e);
            }
        }

        public static long countOnlineUser(JedisPool pool) {
            try (Jedis jedis = pool.getResource()) {
                return jedis.scard(Const.WS_USER_ONLINE);
            } catch (Exception e) {
                log.error("", e);
            }
            return 0;
        }
    }
}
